{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Importing Libraries"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 150,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-05-04T03:58:22.096958Z",
     "start_time": "2020-05-04T03:58:22.093325Z"
    }
   },
   "outputs": [],
   "source": [
    "# Data Visualization\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# Data Processing\n",
    "import pandas as pd\n",
    "\n",
    "# NLP \n",
    "import re\n",
    "\n",
    "# ML\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.metrics import f1_score, accuracy_score, confusion_matrix\n",
    "\n",
    "# BERT\n",
    "from simpletransformers.classification import MultiLabelClassificationModel\n",
    "\n",
    "import logging\n",
    "import warnings"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Preprocessing Data for Bidirectional Multilabel Classification using BERT"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 151,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-05-04T03:58:22.686027Z",
     "start_time": "2020-05-04T03:58:22.676754Z"
    }
   },
   "outputs": [],
   "source": [
    "df = pd.read_pickle(\"../data/preprocessed_dataset\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 152,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-05-04T03:58:22.775761Z",
     "start_time": "2020-05-04T03:58:22.762165Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>questionTitle</th>\n",
       "      <th>questionText</th>\n",
       "      <th>questionLink</th>\n",
       "      <th>topic</th>\n",
       "      <th>answerText</th>\n",
       "      <th>upvotes</th>\n",
       "      <th>views</th>\n",
       "      <th>root_topic</th>\n",
       "      <th>root_multi_label</th>\n",
       "      <th>reflection</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>How can I keep a long distance relationship go...</td>\n",
       "      <td>We weren't long distance until he joined the m...</td>\n",
       "      <td>https://counselchat.com/questions/how-can-i-ke...</td>\n",
       "      <td>relationships</td>\n",
       "      <td>Hello. You are asking a very good question abo...</td>\n",
       "      <td>9</td>\n",
       "      <td>481</td>\n",
       "      <td>family_conflicts</td>\n",
       "      <td>[1, 0, 1]</td>\n",
       "      <td>You are asking a very good question about how ...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>How can I ask my boyfriend about who he's text...</td>\n",
       "      <td>We've been in a long distance relationship for...</td>\n",
       "      <td>https://counselchat.com/questions/how-can-i-as...</td>\n",
       "      <td>relationships</td>\n",
       "      <td>I agree with Sherry that in a close intimate r...</td>\n",
       "      <td>9</td>\n",
       "      <td>472</td>\n",
       "      <td>family_conflicts</td>\n",
       "      <td>[1, 0, 0]</td>\n",
       "      <td>I agree with Sherry that in a close intimate r...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>Why do I feel like I'm always wrong in everyth...</td>\n",
       "      <td>My wife is always accusing me of cheating and ...</td>\n",
       "      <td>https://counselchat.com/questions/why-do-i-fee...</td>\n",
       "      <td>workplace-relationships</td>\n",
       "      <td>Hello. That must be very frustrating for you t...</td>\n",
       "      <td>9</td>\n",
       "      <td>268</td>\n",
       "      <td>others</td>\n",
       "      <td>[1, 0, 1]</td>\n",
       "      <td>That must be very frustrating for you to feel ...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>Why do I feel sad all the time?</td>\n",
       "      <td>I just feel sad all the time and I don't like ...</td>\n",
       "      <td>https://counselchat.com/questions/why-do-i-fee...</td>\n",
       "      <td>family-conflict</td>\n",
       "      <td>Hello,While one can be sad from time to time, ...</td>\n",
       "      <td>9</td>\n",
       "      <td>264</td>\n",
       "      <td>family_conflicts</td>\n",
       "      <td>[1, 1, 0]</td>\n",
       "      <td>If you feel sad on most days, it is worthwhile...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>The Underlying Causes of Being Possessive</td>\n",
       "      <td>I am extremely possessive in my relationships ...</td>\n",
       "      <td>https://counselchat.com/questions/the-underlyi...</td>\n",
       "      <td>social-relationships</td>\n",
       "      <td>Hi there. Its great you are able to realize th...</td>\n",
       "      <td>7</td>\n",
       "      <td>224</td>\n",
       "      <td>others</td>\n",
       "      <td>[0, 0, 1]</td>\n",
       "      <td>Its great you are able to realize there are ot...</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                       questionTitle  \\\n",
       "0  How can I keep a long distance relationship go...   \n",
       "1  How can I ask my boyfriend about who he's text...   \n",
       "2  Why do I feel like I'm always wrong in everyth...   \n",
       "3                    Why do I feel sad all the time?   \n",
       "4          The Underlying Causes of Being Possessive   \n",
       "\n",
       "                                        questionText  \\\n",
       "0  We weren't long distance until he joined the m...   \n",
       "1  We've been in a long distance relationship for...   \n",
       "2  My wife is always accusing me of cheating and ...   \n",
       "3  I just feel sad all the time and I don't like ...   \n",
       "4  I am extremely possessive in my relationships ...   \n",
       "\n",
       "                                        questionLink                    topic  \\\n",
       "0  https://counselchat.com/questions/how-can-i-ke...            relationships   \n",
       "1  https://counselchat.com/questions/how-can-i-as...            relationships   \n",
       "2  https://counselchat.com/questions/why-do-i-fee...  workplace-relationships   \n",
       "3  https://counselchat.com/questions/why-do-i-fee...          family-conflict   \n",
       "4  https://counselchat.com/questions/the-underlyi...     social-relationships   \n",
       "\n",
       "                                          answerText  upvotes  views  \\\n",
       "0  Hello. You are asking a very good question abo...        9    481   \n",
       "1  I agree with Sherry that in a close intimate r...        9    472   \n",
       "2  Hello. That must be very frustrating for you t...        9    268   \n",
       "3  Hello,While one can be sad from time to time, ...        9    264   \n",
       "4  Hi there. Its great you are able to realize th...        7    224   \n",
       "\n",
       "         root_topic root_multi_label  \\\n",
       "0  family_conflicts        [1, 0, 1]   \n",
       "1  family_conflicts        [1, 0, 0]   \n",
       "2            others        [1, 0, 1]   \n",
       "3  family_conflicts        [1, 1, 0]   \n",
       "4            others        [0, 0, 1]   \n",
       "\n",
       "                                          reflection  \n",
       "0  You are asking a very good question about how ...  \n",
       "1  I agree with Sherry that in a close intimate r...  \n",
       "2  That must be very frustrating for you to feel ...  \n",
       "3  If you feel sad on most days, it is worthwhile...  \n",
       "4  Its great you are able to realize there are ot...  "
      ]
     },
     "execution_count": 152,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 153,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-05-04T03:58:22.846839Z",
     "start_time": "2020-05-04T03:58:22.840512Z"
    }
   },
   "outputs": [],
   "source": [
    "df[\"text\"] = df.questionText + ' ' + df.questionTitle\n",
    "df[\"text\"] = df[\"text\"].astype(str)\n",
    "df[\"labels\"] = df.root_multi_label"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 154,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-05-04T03:58:22.928924Z",
     "start_time": "2020-05-04T03:58:22.924283Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "questionTitle       object\n",
       "questionText        object\n",
       "questionLink        object\n",
       "topic               object\n",
       "answerText          object\n",
       "upvotes              int64\n",
       "views                int64\n",
       "root_topic          object\n",
       "root_multi_label    object\n",
       "reflection          object\n",
       "text                object\n",
       "labels              object\n",
       "dtype: object"
      ]
     },
     "execution_count": 154,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df.dtypes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 155,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-05-04T03:58:26.785370Z",
     "start_time": "2020-05-04T03:58:23.007488Z"
    }
   },
   "outputs": [],
   "source": [
    "model = MultiLabelClassificationModel('bert', 'bert-base-uncased', num_labels=3, use_cuda=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 156,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-05-04T03:58:26.795796Z",
     "start_time": "2020-05-04T03:58:26.787736Z"
    }
   },
   "outputs": [],
   "source": [
    "train_df, test_df = train_test_split(df, test_size=0.3, random_state =333)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 157,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-05-04T03:58:26.803960Z",
     "start_time": "2020-05-04T03:58:26.798854Z"
    }
   },
   "outputs": [],
   "source": [
    "train_df = train_df[['text', 'labels']]\n",
    "test_df = test_df[['text', 'labels']]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 158,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-05-04T03:58:26.811119Z",
     "start_time": "2020-05-04T03:58:26.806160Z"
    }
   },
   "outputs": [],
   "source": [
    "train_df = train_df.reset_index()\n",
    "train_df.drop(['index'], axis = 1, inplace = True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 146,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-05-04T03:58:00.596063Z",
     "start_time": "2020-05-04T03:58:00.591450Z"
    }
   },
   "outputs": [],
   "source": [
    "test_df = test_df.reset_index()\n",
    "test_df.drop(['index'], axis = 1, inplace = True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 159,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-05-04T04:53:37.129767Z",
     "start_time": "2020-05-04T03:59:56.043721Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Converting to features started. Cache is not used.\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "3b18f83ba7334dc398fc5c2402b04acc",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=581), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "843e8bafe2af457f9d3bc5004acb58c5",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, description='Epoch', max=10, style=ProgressStyle(description_width='initia…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "383b593388644e78946ef0c096773e6c",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, description='Current iteration', max=42, style=ProgressStyle(description_w…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Running loss: 0.347322"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "fe6e82e3fb5c40b7a59e9e942b2f1000",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, description='Current iteration', max=42, style=ProgressStyle(description_w…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Running loss: 0.546854"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/kaushik-shakkari/anaconda3/lib/python3.7/site-packages/torch/optim/lr_scheduler.py:224: UserWarning: To get the last learning rate computed by the scheduler, please use `get_last_lr()`.\n",
      "  warnings.warn(\"To get the last learning rate computed by the scheduler, \"\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Running loss: 0.590284"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "8e8d6975fadd4514ba2c7a0cc2a905a8",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, description='Current iteration', max=42, style=ProgressStyle(description_w…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Running loss: 0.150954"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "ec322f90f39f4fbb9d374ca19616c437",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, description='Current iteration', max=42, style=ProgressStyle(description_w…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Running loss: 0.246771"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "c5ee9665703c49a8a1159d0896de459c",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, description='Current iteration', max=42, style=ProgressStyle(description_w…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Running loss: 0.045022"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "c35e7fd5ed5e439b8bd00819698a4939",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, description='Current iteration', max=42, style=ProgressStyle(description_w…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Running loss: 0.016176"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a6b07088c5254d36bbad88177a6d7414",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, description='Current iteration', max=42, style=ProgressStyle(description_w…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Running loss: 0.050648"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "d1a6250d9be7432c9ca451f5e5ac8ecc",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, description='Current iteration', max=42, style=ProgressStyle(description_w…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Running loss: 0.014145"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "78a5b69054cf41e6b6d43533112ce1ac",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, description='Current iteration', max=42, style=ProgressStyle(description_w…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Running loss: 0.007887"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "840b615793e34927b6036fef26a9b400",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, description='Current iteration', max=42, style=ProgressStyle(description_w…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Running loss: 0.005918Training of bert model complete. Saved to outputs/.\n"
     ]
    }
   ],
   "source": [
    "model.train_model(train_df, args={'learning_rate':1e-4, 'num_train_epochs': 10, 'reprocess_input_data': True, 'overwrite_output_dir': True,\"train_batch_size\": 14})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 161,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-05-04T08:42:40.814553Z",
     "start_time": "2020-05-04T08:42:06.820494Z"
    }
   },
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "72fb36bd4e2d4af297c31ff149d2071e",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=250), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "2c44e6565bc140dc89c268b97f2d749f",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=32), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "result, model_outputs, wrong_predictions = model.eval_model(test_df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 162,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-05-04T08:42:40.825332Z",
     "start_time": "2020-05-04T08:42:40.817193Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'LRAP': 0.9146666666666668, 'eval_loss': 0.7711860002018511}"
      ]
     },
     "execution_count": 162,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "result"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
